{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "78cd6004",
   "metadata": {},
   "source": [
    "# 生成式预训练语言模型：理论与实战\n",
    "深蓝学院 课程 \n",
    "课程链接：https://www.shenlanxueyuan.com/course/620\n",
    "\n",
    "作者 **黄佳**"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "953bd8fc",
   "metadata": {},
   "source": [
    "## 第1步，构建一个简单的数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9944d45e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 构建一个玩具数据集\n",
    "corpus = [ \"我喜欢吃苹果\",\n",
    "        \"我喜欢吃香蕉\",\n",
    "        \"她喜欢吃葡萄\",\n",
    "        \"他不喜欢吃香蕉\",\n",
    "        \"他喜欢吃苹果\",\n",
    "        \"她喜欢吃草莓\"] "
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ac683369",
   "metadata": {},
   "source": [
    "## 第2步：定义一个分词函数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "78ad7a39",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "单字列表:\n",
      "['我', '喜', '欢', '吃', '苹', '果']\n",
      "['我', '喜', '欢', '吃', '香', '蕉']\n",
      "['她', '喜', '欢', '吃', '葡', '萄']\n",
      "['他', '不', '喜', '欢', '吃', '香', '蕉']\n",
      "['他', '喜', '欢', '吃', '苹', '果']\n",
      "['她', '喜', '欢', '吃', '草', '莓']\n"
     ]
    }
   ],
   "source": [
    "# 定义一个分词函数，将文本转换为单字的列表\n",
    "def tokenize(text):\n",
    "    return [char for char in text]  # 将文本拆分为单字列表\n",
    "\n",
    "# 对每个文本进行分词，并打印出对应的单字列表\n",
    "print(\"单字列表:\") \n",
    "for text in corpus:\n",
    "    tokens = tokenize(text)\n",
    "    print(tokens)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "de01fd2f",
   "metadata": {},
   "source": [
    "## 第3步：计算bigram词频"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "15dea8a1",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Bigram词频:\n",
      "我: {'喜': 2}\n",
      "喜: {'欢': 6}\n",
      "欢: {'吃': 6}\n",
      "吃: {'苹': 2, '香': 2, '葡': 1, '草': 1}\n",
      "苹: {'果': 2}\n",
      "香: {'蕉': 2}\n",
      "她: {'喜': 2}\n",
      "葡: {'萄': 1}\n",
      "他: {'不': 1, '喜': 1}\n",
      "不: {'喜': 1}\n",
      "草: {'莓': 1}\n"
     ]
    }
   ],
   "source": [
    "# 定义计算N-Gram词频的函数\n",
    "from collections import defaultdict, Counter # 导入所需库\n",
    "def count_ngrams(corpus, n):\n",
    "    ngrams_count = defaultdict(Counter)  # 创建一个字典存储N-Gram计数\n",
    "    for text in corpus:  # 遍历语料库中的每个文本\n",
    "        tokens = tokenize(text)  # 对文本进行分词\n",
    "        for i in range(len(tokens) - n + 1):  # 遍历分词结果生成N-Gram\n",
    "            ngram = tuple(tokens[i:i+n])  # 创建一个N-Gram元组\n",
    "            prefix = ngram[:-1]  # 获取N-Gram的前缀\n",
    "            token = ngram[-1]  # 获取N-Gram的目标单字\n",
    "            ngrams_count[prefix][token] += 1  # 更新N-Gram计数\n",
    "    return ngrams_count\n",
    "bigram_counts = count_ngrams(corpus, 2) # 计算Bigram词频\n",
    "print(\"Bigram词频:\") # 打印Bigram词频\n",
    "for prefix, counts in bigram_counts.items():\n",
    "    print(\"{}: {}\".format(\"\".join(prefix), dict(counts)))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "28e5a7c4",
   "metadata": {},
   "source": [
    "## 第4步：计算bigram概率\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "437e7877",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "bigram概率: defaultdict(<class 'collections.Counter'>, {('我',): Counter({'喜': 1.0}), ('喜',): Counter({'欢': 1.0}), ('欢',): Counter({'吃': 1.0}), ('吃',): Counter({'苹': 0.3333333333333333, '香': 0.3333333333333333, '葡': 0.16666666666666666, '草': 0.16666666666666666}), ('苹',): Counter({'果': 1.0}), ('香',): Counter({'蕉': 1.0}), ('她',): Counter({'喜': 1.0}), ('葡',): Counter({'萄': 1.0}), ('他',): Counter({'不': 0.5, '喜': 0.5}), ('不',): Counter({'喜': 1.0}), ('草',): Counter({'莓': 1.0})})\n"
     ]
    }
   ],
   "source": [
    "# 定义计算N-Gram概率的函数\n",
    "def ngram_probabilities(ngram_counts):\n",
    "    ngram_probs = defaultdict(Counter)  # 创建一个字典存储N-Gram概率\n",
    "    for prefix, tokens_count in ngram_counts.items():  # 遍历N-Gram前缀\n",
    "        total_count = sum(tokens_count.values())  # 计算当前前缀的N-Gram计数\n",
    "        for token, count in tokens_count.items():  # 遍历每个前缀的N-Gram\n",
    "            ngram_probs[prefix][token] = count / total_count  # 计算每个N-Gram概率\n",
    "    return ngram_probs\n",
    "bigram_probs = ngram_probabilities(bigram_counts) # 计算bigram概率\n",
    "print(\"\\nbigram概率:\", bigram_probs) # 打印bigram概率"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "8fdb25c7",
   "metadata": {},
   "source": [
    "## 第5步：定义生成下一个词的函数\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "161db50a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义生成下一个词的函数\n",
    "def generate_next_token(prefix, ngram_probs):\n",
    "    if not prefix in ngram_probs:  # 如果前缀不在N-Gram中，返回None\n",
    "        return None\n",
    "    next_token_probs = ngram_probs[prefix]  # 获取当前前缀对应的下一个词的概率\n",
    "    next_token = max(next_token_probs, \n",
    "                     key=next_token_probs.get)  # 选择概率最大的词作为下一个词\n",
    "    return next_token"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "8e93ab75",
   "metadata": {},
   "source": [
    "## 第6步：生成连续文本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "db806c05",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义生成连续文本的函数\n",
    "def generate_text(prefix, ngram_probs, n, length=6):\n",
    "    tokens = list(prefix)  # 将前缀转换为字符列表\n",
    "    for _ in range(length - len(prefix)):  # 根据指定长度生成文本 \n",
    "        # 获取当前前缀对应的下一个词\n",
    "        next_token = generate_next_token(tuple(tokens[-(n-1):]), ngram_probs) \n",
    "        if not next_token: # 如果下一个词为None，跳出循环\n",
    "            break\n",
    "        tokens.append(next_token) # 将下一个词添加到生成的文本中\n",
    "    return \"\".join(tokens) # 将字符列表连接成字符串\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "d26e957d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "生成的文本: 我喜欢吃苹果\n"
     ]
    }
   ],
   "source": [
    "# 输入一个前缀，生成文本\n",
    "generated_text = generate_text(\"我\", bigram_probs, 2)\n",
    "print(\"\\n生成的文本:\", generated_text) # 打印生成的文本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "93246657",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
